package org.zjvis.datascience.service.dag;

import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import org.zjvis.datascience.common.dto.TaskInstanceDTO;
import org.zjvis.datascience.common.enums.AlgPyEnum;
import org.zjvis.datascience.common.enums.TaskTypeEnum;
import org.zjvis.datascience.common.util.RestTemplateUtil;
import org.zjvis.datascience.service.TaskInstanceService;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;

/**
 * @description Flask任务调度器
 * @date 2021-12-24
 */
public class FlaskSubmitRunner implements Callable<TaskRunnerResult> {

    private final static Logger logger = LoggerFactory.getLogger(FlaskSubmitRunner.class);

    private String errorTpl = "{\"status\":500, \"error_msg\":\"%s\"}";
    private String emptyTpl = "{\"status\":0, \"error_msg\":\"%s\"}";

    private TaskInstanceDTO instance;

    private RestTemplateUtil restTemplateUtil;

    private TaskInstanceService taskInstanceService;

    public FlaskSubmitRunner(RestTemplateUtil restTemplateUtil,
                             TaskInstanceService taskInstanceService, TaskInstanceDTO instance) {
        this.restTemplateUtil = restTemplateUtil;
        this.instance = instance;
        this.taskInstanceService = taskInstanceService;
    }

    @Override
    public TaskRunnerResult call() throws Exception {
        TaskRunnerResult result = null;
        if (instance.hasPrecautionaryError()) {
            return new TaskRunnerResult(500, String.format(errorTpl, "error happens when init stage."));
        }
        try {
            if (instance.getType().equals(TaskTypeEnum.TASK_TYPE_ALGOPY.getVal())) {
                // webManagementService.submitFlaskJob(formData, instance);
                this.syncInAndOutput();
                result = this.submit();
            } else if (instance.getType().equals(TaskTypeEnum.TASK_TYPE_MODEL.getVal())) {
                //        mlModelService.exec(appArgs); ML model
                result = this.exec();
            }

        } catch (Exception e) {
            result = new TaskRunnerResult(500,
                    String.format(errorTpl, e.getMessage().replaceAll("\"", "'")));
        }
        return result;
    }

    private void syncInAndOutput() {
        JSONObject jsonObject = JSONObject.parseObject(this.instance.getDataJson());
        JSONObject inputInfo = jsonObject.getJSONObject("inputInfo");
        List<String> inputCols = inputInfo.getJSONArray("input").getJSONObject(0)
                .getJSONArray("tableCols").toJavaList(String.class);
        inputInfo.getJSONArray("output").getJSONObject(0).put("tableCols", inputCols);

        List<String> columnTypes = inputInfo.getJSONArray("input").getJSONObject(0)
                .getJSONArray("columnTypes").toJavaList(String.class);
        inputInfo.getJSONArray("output").getJSONObject(0).put("columnTypes", columnTypes);
        jsonObject.put("inputInfo", inputInfo);
        instance.setDataJson(jsonObject.toJSONString());
    }

    /**
     * 针对有sql语句 或者 appArgs 命令的任务
     *
     * @return
     */
    private TaskRunnerResult exec() {
        //        String appArgs = instance.getSqlText();
//        //在最后拼上taskIns的id参数，算子需要该参数将输出元数据写入mysql
//        if (StringUtils.isEmpty(appArgs)) {
//            appArgs = Long.toString(instance.getId());
//        } else {
//            appArgs = appArgs + " " + instance.getId();
//        }
//        String applicationId = null;
//        try {
//            applicationId = restTemplateUtil.submitPySparkJob(appArgs, appResourcePath);
//        } catch (IOException e) {
//            return new TaskRunnerResult(0, String.format(emptyTpl, "job probably not related to spark. Or there are some problem about the connection."));
//        }
//
//        if (StringUtils.isEmpty(applicationId)) {
//            return new TaskRunnerResult(500, String.format(errorTpl, "job submit fail"));
//        }
//        instance.setApplicationId(applicationId);
//        taskInstanceService.update(instance);
//        appIdInstanceMap.put(applicationId, instance);
//
//        String output = "";
//        int status = 0;
//        while (true) {
//            try {
//                Thread.sleep(1000);
//            } catch (InterruptedException e) {
//
//            }
//            JobStatusVO jobStatusVO = restTemplateUtil
//                    .queryJobStatus(applicationId, instance.getLogInfo());
//
//            if (null == jobStatusVO || StringUtils.isEmpty(jobStatusVO.getState())) {
//                continue;
//            }
//            //未结束的任务跳过
//            if (!JobStatus.jobIsEnd(jobStatusVO.getState())) {
//                continue;
//            }
//            //成功后更新
//            if (JobStatus.SUCCEEDED.toString().equals(jobStatusVO.getFinalStatus())) {
//                try {
//                    Thread.sleep(5000);
//                } catch (InterruptedException e) {
//                    e.printStackTrace();
//                }
//                TaskInstanceDTO instanceNew = taskInstanceService.queryById(instance.getId());
//                if (StringUtils.isNotEmpty(instanceNew.getLogInfo())) {
//                    output = instanceNew.getLogInfo();
//                    instance.setLogInfo(output);
//                    status = JSONObject.parseObject(output).getInteger("status");
//                }
//                instance.setProgress(100);
//                appIdInstanceMap.remove(applicationId);
//                break;
//                //失败后更新
//            } else if (JobStatus.FAILED.toString().equals(jobStatusVO.getFinalStatus())) {
//                status = 500;
//                output = String.format(Constant.errorTpl, jobStatusVO.getDiagnostics()
//                        .replaceAll("\"", "'"));
//                instance.setProgress(100);
//                appIdInstanceMap.remove(applicationId);
//                break;
//                //停止后更新
//            } else if (JobStatus.KILLED.toString().equals(jobStatusVO.getFinalStatus())) {
//                status = 500;
//                output = String.format(Constant.errorTpl, jobStatusVO.getDiagnostics()
//                        .replaceAll("\"", "'"));
//                instance.setProgress(100);
//                appIdInstanceMap.remove(applicationId);
//                break;
//            }
//        }
//
//        return new TaskRunnerResult(status, output);
//        //TODO AY
//        instance.setStatus(TaskInstanceStatus.RUNNING.toString());
//        //TaskDTO task = taskService.queryById(taskId);
//        TaskDTO task = taskService.queryById(taskInstanceService.queryById(instance.getId()).getTaskId());
//        Long taskId = task.getId();
//        JSONObject dataJson = JSONObject.parseObject(task.getDataJson());
//        JSONArray outputs = dataJson.getJSONArray("output");
//        JSONObject output = outputs.getJSONObject(0);
//        Long modelId = dataJson.getLong("modelId");
//        String target = "pipeline." + "ml_" + task.getUserId() + "_";
//        long timeStampModel = System.currentTimeMillis();
//        target += task.getId();
//        dagScheduler.setLastTimeStamp(task, timeStampModel);
//        taskService.update(task);
//        String featureX = dataJson.getString("feature_X");
//        JSONArray inputs = dataJson.getJSONArray("input");
//        JSONObject input = inputs.getJSONObject(0);
//        String source = input.getString("tableName");
//        Character lastChar = source.charAt(source.length() - 1);
//        if (lastChar.equals('_')) {
//            Long parentTimeStamp = dataJson.getJSONArray("parentTimeStamps").getLong(0);
//            source = source + parentTimeStamp;
//        }
//        //String appArgs = params.getString("param");
//        Long sparktime = mlModelService.queryMetricsById(modelId).getSparktime();
//        if (null == sparktime || sparktime.equals(0L)) {
//            //FLASK
//            try {
//                JSONObject params = new JSONObject();
//                params.put("apiPath", "ml_model");
//                params.put("model_id", modelId);
//                params.put("execution", "predict");
//                params.put("target", target);
//                params.put("source", source);
//                params.put("feature_X", featureX);
//                params.put("taskId", taskId);
//                String res = mlModelService.submitFlaskJob(params);
//            } catch (IOException e) {
//                e.printStackTrace();
//            }
//        } else {
//            //SPARK
//            String appArgs = "--id " + modelId + " --target " + target + " --exe predict"
//                    + " --source " + source + " --feature_col " + featureX
//                    + " --task_id " + taskId + " --instance_id " + 100;
//            mlModelService.exec(appArgs);
//        }
//
//
//        JSONObject newParam = dataJson.getJSONObject("param");
//        newParam.put("execution", "predict");
//        newParam.put("feature_X", featureX);
//        newParam.put("source", output.getString("tableName"));
//        newParam.put("model_saved_path", dataJson.getString("modelPath"));
//        newParam.put("model_id", dataJson.getLong("modelId"));
//        newParam.put("target", target);
//        //instance.setStatus(TaskInstanceStatus.SUCCESS.toString());
//        instance.setProgress(100);
//        instance.setDataJson(newParam.toJSONString());
        return null;
    }

    /**
     * @return
     */
    public TaskRunnerResult submit() throws JsonProcessingException {
        Long taskId = instance.getTaskId();
        logger.warn("---------> {} ", taskId);
        JSONObject dataJson = JSONObject.parseObject(instance.getDataJson())
                .getJSONObject("inputInfo");
        JSONObject formData = dataJson.getJSONArray("setParams").getJSONObject(0)
                .getJSONObject("formData");

        logger.warn("---------> dataJson: {} ", dataJson);
        JSONObject input = dataJson.getJSONArray("input").getJSONObject(0);
        String source = input.getString("tableName");
        if (source.endsWith("_")) {
            Long parentTimeStamp = dataJson.getJSONArray("parentTimeStamps").getLong(0);
            source = source + parentTimeStamp;
        }
        int algoVal = dataJson.getInteger("algType");
        String algoEngName = AlgPyEnum.getEnumByVal(algoVal).getName();

        formData.put("taskId", taskId.toString());
        formData.put("target",
                String.format("pipeline.%s_%d_%d", algoEngName, taskId, System.currentTimeMillis()));
        formData.put("source", source);
        formData.put("instanceId", instance.getId().toString());

        if (((algoEngName.equals(AlgPyEnum.SIMULATE.getName()) || algoEngName
                .equals(AlgPyEnum.SIMULATE_NEW.getName())) && formData.getString("tih") == null) ||
                (algoEngName.equals(AlgPyEnum.SIMULATE_INVERSE.getName()) && formData.getString("gdp") == null)) {
            for (Object obj : formData.getJSONArray("setParams")) {
                JSONObject item = (JSONObject) obj;
                String name = item.getString("english_name");
                formData.put(name, item.getString("value"));
            }
        }

        String flaskServer = restTemplateUtil.getFlaskServer();
        String url = String.format("%s/%s", flaskServer, algoEngName);
        logger.warn("---------> url: {} ", url);
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);

        MultiValueMap<String, String> formDataMap = new LinkedMultiValueMap<String, String>();
        for (String keyStr : formData.keySet()) {
            formDataMap.add(keyStr, formData.getString(keyStr));
        }
        logger.warn("formData -> {}", formDataMap.toString());
        WebClient webClient = WebClient.create();
        Flux<String> flux = webClient.post().uri(url)
                .body(BodyInserters.fromFormData(formDataMap)).retrieve().bodyToFlux(String.class);

        JSONObject resJson = JSONObject.parseObject(flux.blockFirst());
        boolean error = false;
        if (resJson.containsKey("error_msg")) {
            instance.setLogInfo(String.format(errorTpl, resJson.getString("error_msg")));
            error = true;
        } else {
            instance.setLogInfo(resJson.toJSONString());
        }
        instance.setProgress(100);
        taskInstanceService.update(instance);
        logger.warn("---------> instance: {} ", instance.toString());
        return error ? TaskRunnerResult.fail(resJson.getString("error_msg"))
                : TaskRunnerResult.ok(resJson.toJSONString());
    }

}
